{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Задание 5.2 - Word2Vec with Negative Sampling\n",
    "\n",
    "В этом задании мы натренируем свои версию word vectors с negative sampling на том же небольшом датасете.\n",
    "\n",
    "\n",
    "Несмотря на то, что основная причина использования Negative Sampling - улучшение скорости тренировки word2vec, в нашем игрушечном примере мы **не требуем** улучшения производительности. Мы используем negative sampling просто как дополнительное упражнение для знакомства с PyTorch.\n",
    "\n",
    "Перед запуском нужно запустить скрипт `download_data.sh`, чтобы скачать данные.\n",
    "\n",
    "Датасет и модель очень небольшие, поэтому это задание можно выполнить и без GPU.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "from torchvision import transforms\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# We'll use Principal Component Analysis (PCA) to visualize word vectors,\n",
    "# so make sure you install dependencies from requirements.txt!\n",
    "from sklearn.decomposition import PCA \n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "class StanfordTreeBank:\n",
    "    '''\n",
    "    Wrapper for accessing Stanford Tree Bank Dataset\n",
    "    https://nlp.stanford.edu/sentiment/treebank.html\n",
    "    \n",
    "    Parses dataset, gives each token and index and provides lookups\n",
    "    from string token to index and back\n",
    "    \n",
    "    Allows to generate random context with sampling strategy described in\n",
    "    word2vec paper:\n",
    "    https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf\n",
    "    '''\n",
    "    def __init__(self):\n",
    "        self.index_by_token = {} # map of string -> token index\n",
    "        self.token_by_index = []\n",
    "\n",
    "        self.sentences = []\n",
    "\n",
    "        self.token_freq = {}\n",
    "        \n",
    "        self.token_reject_by_index = None\n",
    "\n",
    "    def load_dataset(self, folder):\n",
    "        filename = os.path.join(folder, \"datasetSentences.txt\")\n",
    "\n",
    "        with open(filename, \"r\", encoding=\"latin1\") as f:\n",
    "            l = f.readline() # skip the first line\n",
    "            \n",
    "            for l in f:\n",
    "                splitted_line = l.strip().split()\n",
    "                words = [w.lower() for w in splitted_line[1:]] # First one is a number\n",
    "                    \n",
    "                self.sentences.append(words)\n",
    "                for word in words:\n",
    "                    if word in self.token_freq:\n",
    "                        self.token_freq[word] +=1 \n",
    "                    else:\n",
    "                        index = len(self.token_by_index)\n",
    "                        self.token_freq[word] = 1\n",
    "                        self.index_by_token[word] = index\n",
    "                        self.token_by_index.append(word)\n",
    "        self.compute_token_prob()\n",
    "                        \n",
    "    def compute_token_prob(self):\n",
    "        words_count = np.array([self.token_freq[token] for token in self.token_by_index])\n",
    "        words_freq = words_count / np.sum(words_count)\n",
    "        \n",
    "        # Following sampling strategy from word2vec paper\n",
    "        self.token_reject_by_index = 1- np.sqrt(1e-5/words_freq)\n",
    "    \n",
    "    def check_reject(self, word):\n",
    "        return np.random.rand() > self.token_reject_by_index[self.index_by_token[word]]\n",
    "        \n",
    "    def get_random_context(self, context_length=5):\n",
    "        \"\"\"\n",
    "        Returns tuple of center word and list of context words\n",
    "        \"\"\"\n",
    "        sentence_sampled = []\n",
    "        while len(sentence_sampled) <= 2:\n",
    "            sentence_index = np.random.randint(len(self.sentences)) \n",
    "            sentence = self.sentences[sentence_index]\n",
    "            sentence_sampled = [word for word in sentence if self.check_reject(word)]\n",
    "    \n",
    "        center_word_index = np.random.randint(len(sentence_sampled))\n",
    "        \n",
    "        words_before = sentence_sampled[max(center_word_index - context_length//2,0):center_word_index]\n",
    "        words_after = sentence_sampled[center_word_index+1: center_word_index+1+context_length//2]\n",
    "        \n",
    "        return sentence_sampled[center_word_index], words_before+words_after\n",
    "    \n",
    "    def num_tokens(self):\n",
    "        return len(self.token_by_index)\n",
    "        \n",
    "data = StanfordTreeBank()\n",
    "data.load_dataset(\"./stanfordSentimentTreebank/\")\n",
    "\n",
    "print(\"Num tokens:\", data.num_tokens())\n",
    "for i in range(5):\n",
    "    center_word, other_words = data.get_random_context(5)\n",
    "    print(center_word, other_words)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset для Negative Sampling должен быть немного другим\n",
    "\n",
    "Как и прежде, Dataset должен сгенерировать много случайных контекстов и превратить их в сэмплы для тренировки.\n",
    "\n",
    "Здесь мы реализуем прямой проход модели сами, поэтому выдавать данные можно в удобном нам виде.\n",
    "Напоминаем, что в случае negative sampling каждым сэмплом является:\n",
    "- вход: слово в one-hot представлении\n",
    "- выход: набор из одного целевого слова и K других случайных слов из словаря.\n",
    "Вместо softmax + cross-entropy loss, сеть обучается через binary cross-entropy loss - то есть, предсказывает набор бинарных переменных, для каждой из которых функция ошибки считается независимо.\n",
    "\n",
    "Для целевого слова бинарное предсказание должно быть позитивным, а для K случайных слов - негативным.\n",
    "\n",
    "Из набора слово-контекст создается N сэмплов (где N - количество слов в контексте), в каждом из них K+1 целевых слов, для только одного из которых предсказание должно быть позитивным.\n",
    "Например, для K=2:\n",
    "\n",
    "Слово: `orders` и контекст: `['love', 'nicest', 'to', '50-year']` создадут 4 сэмпла:\n",
    "- input: `orders`, target: `[love: 1, any: 0, rose: 0]`\n",
    "- input: `orders`, target: `[nicest: 1, fool: 0, grass: 0]`\n",
    "- input: `orders`, target: `[to: 1, -: 0, the: 0]`\n",
    "- input: `orders`, target: `[50-year: 1, ?: 0, door: 0]`\n",
    "\n",
    "Все слова на входе и на выходе закодированы через one-hot encoding, с размером вектора равным количеству токенов."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_negative_samples = 10\n",
    "\n",
    "class Word2VecNegativeSampling(Dataset):\n",
    "    '''\n",
    "    PyTorch Dataset for Word2Vec with Negative Sampling.\n",
    "    Accepts StanfordTreebank as data and is able to generate dataset based on\n",
    "    a number of random contexts\n",
    "    '''\n",
    "    def __init__(self, data, num_negative_samples, num_contexts=30000):\n",
    "        '''\n",
    "        Initializes Word2VecNegativeSampling, but doesn't generate the samples yet\n",
    "        (for that, use generate_dataset)\n",
    "        Arguments:\n",
    "        data - StanfordTreebank instace\n",
    "        num_negative_samples - number of negative samples to generate in addition to a positive one\n",
    "        num_contexts - number of random contexts to use when generating a dataset\n",
    "        '''\n",
    "        \n",
    "        # TODO: Implement what you need for other methods!\n",
    "    \n",
    "    def generate_dataset(self):\n",
    "        '''\n",
    "        Generates dataset samples from random contexts\n",
    "        Note: there will be more samples than contexts because every context\n",
    "        can generate more than one sample\n",
    "        '''\n",
    "        # TODO: Implement generating the dataset\n",
    "        # You should sample num_contexts contexts from the data and turn them into samples\n",
    "        # Note you will have several samples from one context\n",
    "        \n",
    "    def __len__(self):\n",
    "        '''\n",
    "        Returns total number of samples\n",
    "        '''\n",
    "        # TODO: Return the number of samples\n",
    "\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        '''\n",
    "        Returns i-th sample\n",
    "        \n",
    "        Return values:\n",
    "        input_vector - index of the input word (not torch.Tensor!)\n",
    "        output_indices - torch.Tensor of indices of the target words. Should be 1+num_negative_samples.\n",
    "        output_target - torch.Tensor with float targets for the training. Should be the same size as output_indices\n",
    "                        and have 1 for the context word and 0 everywhere else\n",
    "        '''\n",
    "        # TODO: Generate tuple of 3 return arguments for i-th sample\n",
    "\n",
    "dataset = Word2VecNegativeSampling(data, num_negative_samples, 10)\n",
    "dataset.generate_dataset()\n",
    "input_vector, output_indices, output_target = dataset[0]\n",
    "\n",
    "print(\"Sample - input: %s, output indices: %s, output target: %s\" % (int(input_vector), output_indices, output_target)) # target should be able to convert to int\n",
    "assert isinstance(output_indices, torch.Tensor)\n",
    "assert output_indices.shape[0] == num_negative_samples+1\n",
    "\n",
    "assert isinstance(output_target, torch.Tensor)\n",
    "assert output_target.shape[0] == num_negative_samples+1\n",
    "assert torch.sum(output_target) == 1.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Создаем модель\n",
    "\n",
    "Для нашей задачи нам придется реализовать свою собственную PyTorch модель.\n",
    "Эта модель реализует свой собственный прямой проход (forward pass), который получает на вход индекс входного слова и набор индексов для выходных слов. \n",
    "\n",
    "Как всегда, на вход приходит не один сэмпл, а целый batch.  \n",
    "Напомним, что цели улучшить скорость тренировки у нас нет, достаточно чтобы она сходилась."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the usual PyTorch structures\n",
    "dataset = Word2VecNegativeSampling(data, num_negative_samples, 30000)\n",
    "dataset.generate_dataset()\n",
    "\n",
    "# As before, we'll be training very small word vectors!\n",
    "wordvec_dim = 10\n",
    "\n",
    "class Word2VecNegativeSamples(nn.Module):\n",
    "    def __init__(self, num_tokens):\n",
    "        super(Word2VecNegativeSamples, self).__init__()\n",
    "        self.input = nn.Linear(num_tokens, 10, bias=False)\n",
    "        self.ouput = nn.Linear(10, num_tokens, bias=False)\n",
    "        \n",
    "    def forward(self, input_index_batch, output_indices_batch):\n",
    "        '''\n",
    "        Implements forward pass with negative sampling\n",
    "        \n",
    "        Arguments:\n",
    "        input_index_batch - Tensor of ints, shape: (batch_size, ), indices of input words in the batch\n",
    "        output_indices_batch - Tensor if ints, shape: (batch_size, num_negative_samples+1),\n",
    "                                indices of the target words for every sample\n",
    "                                \n",
    "        Returns:\n",
    "        predictions - Tensor of floats, shape: (batch_size, um_negative_samples+1)\n",
    "        '''\n",
    "        results = []\n",
    "        \n",
    "        # TODO Implement forward pass\n",
    "        # Hint: You can use for loop to go over all samples on the batch,\n",
    "        # run every sample indivisually and then use\n",
    "        # torch.stack or torch.cat to produce the final result\n",
    "    \n",
    "nn_model = Word2VecNegativeSamples(data.num_tokens())\n",
    "nn_model.type(torch.FloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_word_vectors(nn_model):\n",
    "    '''\n",
    "    Extracts word vectors from the model\n",
    "    \n",
    "    Returns:\n",
    "    input_vectors: torch.Tensor with dimensions (num_tokens, num_dimensions)\n",
    "    output_vectors: torch.Tensor with dimensions (num_tokens, num_dimensions)\n",
    "    '''\n",
    "    # TODO: Implement extracting word vectors from param weights\n",
    "    # return tuple of input vectors and output vectos \n",
    "\n",
    "untrained_input_vectors, untrained_output_vectors = extract_word_vectors(nn_model)\n",
    "assert untrained_input_vectors.shape == (data.num_tokens(), wordvec_dim)\n",
    "assert untrained_output_vectors.shape == (data.num_tokens(), wordvec_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_neg_sample(model, dataset, train_loader, optimizer, scheduler, num_epochs):    \n",
    "    '''\n",
    "    Trains word2vec with negative samples on and regenerating dataset every epoch\n",
    "    \n",
    "    Returns:\n",
    "    loss_history, train_history\n",
    "    '''\n",
    "    loss = nn.BCEWithLogitsLoss().type(torch.FloatTensor)\n",
    "    loss_history = []\n",
    "    train_history = []\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train() # Enter train mode\n",
    "        \n",
    "        dataset.generate_dataset()\n",
    "        \n",
    "        # TODO: Implement training using negative samples\n",
    "        # You can estimate accuracy by comparing prediction values with 0\n",
    "        # And don't forget to step the scheduler!\n",
    "        \n",
    "        print(\"Average loss: %f, Train accuracy: %f\" % (ave_loss, train_accuracy))\n",
    "        \n",
    "    return loss_history, train_history"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ну и наконец тренировка!\n",
    "\n",
    "Добейтесь значения ошибки меньше **0.25**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finally, let's train the model!\n",
    "\n",
    "# TODO: We use placeholder values for hyperparameters - you will need to find better values!\n",
    "optimizer = optim.SGD(nn_model.parameters(), lr=1e-1, weight_decay=0)\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)\n",
    "train_loader = torch.utils.data.DataLoader(dataset, batch_size=20)\n",
    "\n",
    "loss_history, train_history = train_neg_sample(nn_model, dataset, train_loader, optimizer, scheduler, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize training graphs\n",
    "plt.subplot(211)\n",
    "plt.plot(train_history)\n",
    "plt.subplot(212)\n",
    "plt.plot(loss_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Визуализируем вектора для разного вида слов до и после тренировки\n",
    "\n",
    "Как и ранее, в случае успешной тренировки вы должны увидеть как вектора слов разных типов (например, знаков препинания, предлогов и остальных)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_input_vectors, trained_output_vectors = extract_word_vectors(nn_model)\n",
    "assert trained_input_vectors.shape == (data.num_tokens(), wordvec_dim)\n",
    "assert trained_output_vectors.shape == (data.num_tokens(), wordvec_dim)\n",
    "\n",
    "def visualize_vectors(input_vectors, output_vectors, title=''):\n",
    "    full_vectors = torch.cat((input_vectors, output_vectors), 0)\n",
    "    wordvec_embedding = PCA(n_components=2).fit_transform(full_vectors)\n",
    "\n",
    "    # Helpful words form CS244D example\n",
    "    # http://cs224d.stanford.edu/assignment1/index.html\n",
    "    visualize_words = {'green': [\"the\", \"a\", \"an\"], \n",
    "                      'blue': [\",\", \".\", \"?\", \"!\", \"``\", \"''\", \"--\"], \n",
    "                      'brown': [\"good\", \"great\", \"cool\", \"brilliant\", \"wonderful\", \n",
    "                              \"well\", \"amazing\", \"worth\", \"sweet\", \"enjoyable\"],\n",
    "                      'orange': [\"boring\", \"bad\", \"waste\", \"dumb\", \"annoying\", \"stupid\"],\n",
    "                      'red': ['tell', 'told', 'said', 'say', 'says', 'tells', 'goes', 'go', 'went']\n",
    "                     }\n",
    "\n",
    "    plt.figure(figsize=(7,7))\n",
    "    plt.suptitle(title)\n",
    "    for color, words in visualize_words.items():\n",
    "        points = np.array([wordvec_embedding[data.index_by_token[w]] for w in words])\n",
    "        for i, word in enumerate(words):\n",
    "            plt.text(points[i, 0], points[i, 1], word, color=color,horizontalalignment='center')\n",
    "        plt.scatter(points[:, 0], points[:, 1], c=color, alpha=0.3, s=0.5)\n",
    "\n",
    "visualize_vectors(untrained_input_vectors, untrained_output_vectors, \"Untrained word vectors\")\n",
    "visualize_vectors(trained_input_vectors, trained_output_vectors, \"Trained word vectors\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
